#include "model_mdl.h"
#include <gflags/gflags.h>
#include <iostream>

DECLARE_bool(share_param_mem);

using namespace lar;

//! construct with the path of a serialized .mdl model; a fresh computing
//! graph is created with graph optimization disabled so the model is loaded
//! exactly as dumped (comp node defaults to XPU)
ModelMdl::ModelMdl(const std::string& path) : model_path(path) {
    //! fixed typo in log message: "creat" -> "create"
    mgb_log("create mdl model use XPU as default comp node");
    m_load_config.comp_graph = mgb::ComputingGraph::make();
    //! opt level 0: later passes can still be applied explicitly by the runner
    m_load_config.comp_graph->options().graph_opt_level = 0;
    testcase_num = 0;
}

//! load the serialized model from model_path, detect its dump format,
//! deserialize the computing graph, and collect testcase inputs (when the
//! model was dumped by dump_with_testcase.py) plus one empty output callback
//! slot per output var
void ModelMdl::load_model() {
    //! read dump file; optionally keep the raw bytes in memory so the loader
    //! can share them instead of re-reading from disk
    if (share_model_mem) {
        mgb_log("enable share model memory");
        FILE* fin = fopen(model_path.c_str(), "rb");
        mgb_assert(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
        fseek(fin, 0, SEEK_END);
        //! ftell returns -1 on failure; check before converting to size_t,
        //! otherwise a failure would silently become a huge allocation size
        long file_size = ftell(fin);
        mgb_assert(
                file_size >= 0, "failed to get size of %s: %s", model_path.c_str(),
                strerror(errno));
        size_t size = static_cast<size_t>(file_size);
        fseek(fin, 0, SEEK_SET);

        void* ptr = malloc(size);
        //! guard against allocation failure before handing the pointer to fread
        mgb_assert(ptr, "failed to allocate %zu bytes for model buffer", size);
        std::shared_ptr<void> buf{ptr, free};
        auto nr = fread(buf.get(), 1, size, fin);
        mgb_assert(nr == size, "read model file failed");
        fclose(fin);

        m_model_file = mgb::serialization::InputFile::make_mem_proxy(buf, size);
    } else {
        m_model_file = mgb::serialization::InputFile::make_fs(model_path.c_str());
    }

    //! models produced by dump_with_testcase.py start with the "mgbtest0"
    //! magic followed by the number of embedded testcases; plain models do
    //! not, so rewind and treat the bytes as the graph itself
    char magic[8];
    m_model_file->read(magic, sizeof(magic));
    if (strncmp(magic, "mgbtest0", 8)) {
        m_model_file->rewind();
    } else {
        m_model_file->read(&testcase_num, sizeof(testcase_num));
    }

    m_format =
            mgb::serialization::GraphLoader::identify_graph_dump_format(*m_model_file);
    mgb_assert(
            m_format.valid(),
            "invalid format, please make sure model is dumped by GraphDumper");

    //! load computing graph of model
    m_loader = mgb::serialization::GraphLoader::make(
            std::move(m_model_file), m_format.val());
    m_load_result = m_loader->load(m_load_config, false);
    //! the load result now owns the graph; drop our config reference
    m_load_config.comp_graph.reset();

    //! get testcase input tensors generated by dump_with_testcase.py, sorted
    //! by name for deterministic ordering
    if (testcase_num) {
        for (auto&& i : m_load_result.tensor_map) {
            test_input_tensors.emplace_back(i.first, i.second.get());
        }
        std::sort(test_input_tensors.begin(), test_input_tensors.end());
    }
    //! initialize one (empty) output callback slot per output var; filled in
    //! later and consumed by make_output_spec()
    for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
        mgb::ComputingGraph::Callback cb;
        m_callbacks.push_back(cb);
    }
}

void ModelMdl::make_output_spec() {
    for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
        auto item = m_load_result.output_var_list[i];
        m_output_spec.emplace_back(item, std::move(m_callbacks[i]));
    }

    m_asyc_exec = m_load_result.graph_compile(m_output_spec);
    auto new_output_vars = m_asyc_exec->get_output_vars();
    mgb::cg::SymbolVarArray symbol_var_array;
    symbol_var_array.reserve(new_output_vars.size());
    for (auto output_var : new_output_vars) {
        symbol_var_array.emplace_back(output_var);
    }
    m_load_result.output_var_list = symbol_var_array;
}

//! rebuild the graph loader, either from a caller-supplied input file or by
//! rewinding the current loader's file; the detected dump format is reused
std::shared_ptr<mgb::serialization::GraphLoader>& ModelMdl::reset_loader(
        std::unique_ptr<mgb::serialization::InputFile> input_file) {
    //! only the chosen operand of the conditional is evaluated, so
    //! reset_file() is called solely when no replacement file was given
    auto file = input_file ? std::move(input_file) : m_loader->reset_file();
    m_loader = mgb::serialization::GraphLoader::make(
            std::move(file), m_loader->format());
    return m_loader;
}

//! kick off asynchronous execution of the compiled graph; make_output_spec()
//! must have been called first so m_asyc_exec is valid
void ModelMdl::run_model() {
    //! fixed typo in assert message: "asychronous" -> "asynchronous"
    mgb_assert(
            m_asyc_exec != nullptr,
            "empty asynchronous function to execute after graph compiled");
    m_asyc_exec->execute();
}

void ModelMdl::wait() {
    m_asyc_exec->wait();
}

#if MGB_ENABLE_JSON
std::shared_ptr<mgb::json::Object> ModelMdl::get_io_info() {
    std::shared_ptr<mgb::json::Array> inputs = mgb::json::Array::make();
    std::shared_ptr<mgb::json::Array> outputs = mgb::json::Array::make();
    auto get_dtype = [&](megdnn::DType data_type) {
        std::map<megdnn::DTypeEnum, std::string> type_map = {
                {mgb::dtype::Float32().enumv(), "float32"},
                {mgb::dtype::Int32().enumv(), "int32"},
                {mgb::dtype::Int16().enumv(), "int16"},
                {mgb::dtype::Uint16().enumv(), "uint16"},
                {mgb::dtype::Int8().enumv(), "int8"},
                {mgb::dtype::Uint8().enumv(), "uint8"}};
        return type_map[data_type.enumv()];
    };
    auto make_shape = [](mgb::TensorShape& shape_) {
        std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
                shape;
        for (size_t i = 0; i < shape_.ndim; ++i) {
            std::string lable = "dim";
            lable += std::to_string(shape_.ndim - i - 1);
            shape.push_back(
                    {mgb::json::String(lable),
                     mgb::json::NumberInt::make(shape_[shape_.ndim - i - 1])});
        }
        return shape;
    };
    for (auto&& i : m_load_result.tensor_map) {
        std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
                json_inp;
        auto shape_ = i.second->shape();
        json_inp.push_back(
                {mgb::json::String("shape"),
                 mgb::json::Object::make(make_shape(shape_))});
        json_inp.push_back(
                {mgb::json::String("dtype"),
                 mgb::json::String::make(get_dtype(i.second->dtype()))});
        json_inp.push_back(
                {mgb::json::String("name"), mgb::json::String::make(i.first)});
        inputs->add(mgb::json::Object::make(json_inp));
    }

    for (auto&& i : m_load_result.output_var_list) {
        std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
                json_out;
        auto shape_ = i.shape();
        json_out.push_back(
                {mgb::json::String("shape"),
                 mgb::json::Object::make(make_shape(shape_))});
        json_out.push_back(
                {mgb::json::String("dtype"),
                 mgb::json::String::make(get_dtype(i.dtype()))});

        json_out.push_back(
                {mgb::json::String("name"), mgb::json::String::make(i.node()->name())});
        outputs->add(mgb::json::Object::make(json_out));
    }
    return mgb::json::Object::make(
            {{"IO",
              mgb::json::Object::make({{"outputs", outputs}, {"inputs", inputs}})}});
}
#endif

std::vector<uint8_t> ModelMdl::get_model_data() {
    std::vector<uint8_t> out_data;
    auto out_file = mgb::serialization::OutputFile::make_vector_proxy(&out_data);
    using DumpConfig = mgb::serialization::GraphDumper::DumpConfig;
    DumpConfig config{1, false, false};
    auto dumper =
            mgb::serialization::GraphDumper::make(std::move(out_file), m_format.val());
    dumper->dump(m_load_result.output_var_list, config);
    return out_data;
}

void ModelMdl::update_io() {
    //! rebuild the output varlist when input shapes may have changed (some
    //! pass execution time depends on the shape of the initial input)
    mgb::thin_hash_table::ThinHashMap<mgb::cg::SymbolVar, mgb::cg::SymbolVar> varmap;
    auto&& network = m_load_result;
    //! map raw host-tensor storage pointers back to the tensor names recorded
    //! at load time, so replacement H2D oprs can keep their original names
    std::unordered_map<void*, std::string> tensor_name_map;
    for (auto& input : network.tensor_map) {
        tensor_name_map.insert({input.second->raw_ptr(), input.first});
    }
    //! visit every operator the outputs depend on; for each Host2DeviceCopy
    //! whose host tensor is a known model input, build a replacement H2D
    //! carrying a fresh copy of the (possibly reshaped) host tensor
    mgb::cg::DepOprIter dep([&](mgb::cg::OperatorNodeBase* opr) {
        if (auto h2d = opr->try_cast_final<mgb::opr::Host2DeviceCopy>()) {
            if (tensor_name_map.find(h2d->host_data()->raw_ptr()) !=
                tensor_name_map.end()) {
                //! make new h2d opr with new host tensor shape
                std::string name = tensor_name_map[h2d->host_data()->raw_ptr()];
                std::shared_ptr<mgb::HostTensorND> new_tensor =
                        std::make_shared<mgb::HostTensorND>();
                new_tensor->copy_from(*h2d->host_data());

                auto h2d_opr = mgb::opr::Host2DeviceCopy::make(
                        *h2d->owner_graph(), new_tensor, h2d->param(), h2d->config());
                //! rename new h2d with given name
                h2d_opr.node()->owner_opr()->name(name);
                varmap[h2d->output(0)] = h2d_opr;
            }
        }
    });
    //! walk from every output var to populate the replace-var map
    for (auto&& i : network.output_var_list)
        dep.add(i);
    //! apply the substitution so the output vars see the new h2d oprs and
    //! their updated shapes; no-op when nothing needed replacing
    if (!varmap.empty()) {
        auto output_vars = mgb::cg::replace_vars(network.output_var_list, varmap);
        network.output_var_list = output_vars;
    }
}